{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cdbe4197",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets this notebook supports; `use_dataset` picks which one the cells below load and evaluate.\n",
    "datasets = ['CUB', 'Derm7pt', 'RIVAL10']\n",
    "use_dataset = datasets[1]  # index 1 -> 'Derm7pt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c415bffb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "notebook_dir = os.getcwd()\n",
    "project_root_path = os.path.dirname(notebook_dir)\n",
    "sys.path.insert(0, project_root_path)\n",
    "\n",
    "from src.config import CUB_CONFIG, DERM7PT_CONFIG, RIVAL10_CONFIG  # noqa: E402\n",
    "from src.config import PROJECT_ROOT  # noqa: E402\n",
    "import numpy as np  # noqa: E402"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc4dd3dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each supported dataset name to its configuration. An explicit lookup\n",
    "# (instead of a catch-all `else`) makes a misspelled `use_dataset` fail loudly\n",
    "# rather than silently falling back to the RIVAL10 settings.\n",
    "dataset_configs = {\n",
    "    'CUB': CUB_CONFIG,\n",
    "    'Derm7pt': DERM7PT_CONFIG,\n",
    "    'RIVAL10': RIVAL10_CONFIG,\n",
    "}\n",
    "if use_dataset not in dataset_configs:\n",
    "    raise ValueError(\n",
    "        f\"Unknown dataset {use_dataset!r}; expected one of {sorted(dataset_configs)}\"\n",
    "    )\n",
    "\n",
    "config_dict = dataset_configs[use_dataset]\n",
    "# Directory the later cells read this dataset's .npy arrays from:\n",
    "# <PROJECT_ROOT>/output/<dataset name>.\n",
    "DATASET_PATH = os.path.join(PROJECT_ROOT, 'output', use_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffe4b59a",
   "metadata": {},
   "source": [
    "# Load and Transform Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fc1bbd44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# INSTANCE-BASED CUB MODEL\n",
    "\n",
    "# C_train = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'C_train_instance.npy'))\n",
    "# C_hat_train = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'C_hat_sigmoid_train_instance.npy'))\n",
    "# one_hot_Y_train = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'Y_train_instance.npy'))\n",
    "\n",
    "# C_test = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'C_test_instance.npy'))\n",
    "# C_hat_test = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'C_hat_sigmoid_test_instance.npy'))\n",
    "# one_hot_Y_test = np.load(os.path.join(PROJECT_ROOT, 'output', 'CUB', 'Y_test_instance.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b037fd34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train split: concept predictions and labels, read from the selected dataset's directory.\n",
    "C_hat_train = np.load(os.path.join(DATASET_PATH, 'C_hat_sigmoid_train.npy'))\n",
    "one_hot_Y_train = np.load(os.path.join(DATASET_PATH, 'Y_train.npy'))\n",
    "\n",
    "# Test split.\n",
    "C_hat_test = np.load(os.path.join(DATASET_PATH, 'C_hat_sigmoid_test.npy'))\n",
    "one_hot_Y_test = np.load(os.path.join(DATASET_PATH, 'Y_test.npy'))\n",
    "\n",
    "# For Derm7pt, validation arrays are loaded as well and appended to the training arrays.\n",
    "if use_dataset == 'Derm7pt':\n",
    "    C_hat_val = np.load(os.path.join(DATASET_PATH, 'C_hat_sigmoid_val.npy'))\n",
    "    one_hot_Y_val = np.load(os.path.join(DATASET_PATH, 'Y_val.npy'))\n",
    "\n",
    "    C_hat_train = np.concatenate((C_hat_train, C_hat_val), axis=0)\n",
    "    one_hot_Y_train = np.concatenate((one_hot_Y_train, one_hot_Y_val), axis=0)\n",
    "\n",
    "# Indexed by class id further down to look up each class's concept vector.\n",
    "class_level_concepts = np.load(os.path.join(DATASET_PATH, 'class_level_concepts.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dc3226bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse the label matrices to integer class ids (index of the largest entry in each row).\n",
    "Y_train = one_hot_Y_train.argmax(axis=1)\n",
    "Y_test = one_hot_Y_test.argmax(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "85ad0c07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class-level concept vector for every training instance, looked up by its class id.\n",
    "# Fancy indexing does the per-sample lookup in one step (no Python-level loop) and\n",
    "# keeps the 2-D (n_samples, n_concepts) shape even when Y_train is empty.\n",
    "C_train = class_level_concepts[Y_train]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "47193d3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.utils import shuffle\n",
    "\n",
    "# Shuffle all four training arrays in one call so their rows stay aligned; the seed is fixed for repeatability.\n",
    "C_hat_train, C_train, one_hot_Y_train, Y_train = shuffle(C_hat_train, C_train, one_hot_Y_train, Y_train, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cae5dd1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# unique, counts = np.unique(Y_train, return_counts=True)\n",
    "# for label, count in zip(unique, counts):\n",
    "#     print(f\"Label {label}: {count} instances\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3a0a2d1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# unique, counts = np.unique(Y_test, return_counts=True)\n",
    "# for label, count in zip(unique, counts):\n",
    "#     print(f\"Label {label}: {count} instances\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb99f9c0",
   "metadata": {},
   "source": [
    "# Classic Models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0424779b",
   "metadata": {},
   "source": [
    "## Logistic Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7a6ed058",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logistic Regression Test accuracy: 0.6632911392405063\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# Fit a logistic regression on the predicted concepts and report its accuracy on the test split.\n",
    "lin_model = LogisticRegression(max_iter=1000)\n",
    "lin_model.fit(C_hat_train, Y_train)\n",
    "print(f\"Logistic Regression Test accuracy: {lin_model.score(C_hat_test, Y_test)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "43484e2a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the distinct class ids the logistic regression predicts on the test set.\n",
    "np.unique(lin_model.predict(C_hat_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0d5bbd2",
   "metadata": {},
   "source": [
    "Save incorrectly classified instances for intervention experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aa5d5d99",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(133, 19)\n",
      "(133,)\n"
     ]
    }
   ],
   "source": [
    "# Positions of the test instances the logistic regression misclassifies.\n",
    "Y_pred = lin_model.predict(C_hat_test)\n",
    "wrong_indices = np.flatnonzero(Y_pred != Y_test)\n",
    "\n",
    "# Keep only the misclassified rows (concept predictions and their true labels).\n",
    "C_hat_wrong, Y_wrong = C_hat_test[wrong_indices], Y_test[wrong_indices]\n",
    "\n",
    "print(C_hat_wrong.shape, Y_wrong.shape, sep='\\n')\n",
    "\n",
    "# Written under <PROJECT_ROOT>/output/intervention/<dataset name>/.\n",
    "output_dir = os.path.join(PROJECT_ROOT, 'output', 'intervention', use_dataset)\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "np.save(os.path.join(output_dir, 'C_hat_linear.npy'), C_hat_wrong)\n",
    "np.save(os.path.join(output_dir, 'Y_linear.npy'), Y_wrong)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d34ca1da",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['/Users/pb/Documents/career/lab_ujm/hybrid-cbm-prototype-model/models/Derm7pt/lin_model.joblib']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import joblib\n",
    "\n",
    "# Persist the fitted logistic regression under <PROJECT_ROOT>/models/<dataset name>/.\n",
    "model_dir = os.path.join(PROJECT_ROOT, 'models', use_dataset)\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "\n",
    "# Trailing semicolon: joblib.dump returns the list of written file names, and\n",
    "# echoing it would store this machine's absolute path in the notebook output.\n",
    "joblib.dump(lin_model, os.path.join(model_dir, 'lin_model.joblib'));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "973410e7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.41021713, 0.40787597, 0.40497543, 0.3868018 , 0.42114473])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# np.mean(np.abs(lin_model.coef_), axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "419e627c",
   "metadata": {},
   "source": [
    "## k-NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "3ec9af79",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "k-NN Test accuracy: 0.6\n"
     ]
    }
   ],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "# k-NN baseline with scikit-learn's default settings.\n",
    "# NOTE: the name `model` is rebound by later cells (decision tree, prototype classifier).\n",
    "model = KNeighborsClassifier()\n",
    "model.fit(C_hat_train, Y_train)\n",
    "print(f\"k-NN Test accuracy: {model.score(C_hat_test, Y_test)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e442a104",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: distinct class ids predicted on the test set by whatever `model` currently holds (the k-NN when run in order).\n",
    "np.unique(model.predict(C_hat_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a14b904",
   "metadata": {},
   "source": [
    "## Decision Tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2e8610c3-c78b-4997-9055-10af7ccf5958",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decision Tree Test accuracy: 0.5670886075949367\n"
     ]
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Fix the seed: scikit-learn permutes the features at every split and breaks ties\n",
    "# between equally good splits at random, so an unseeded tree can report a\n",
    "# different accuracy on every run.\n",
    "model = DecisionTreeClassifier(random_state=42)\n",
    "model.fit(C_hat_train, Y_train)\n",
    "print(f\"Decision Tree Test accuracy: {model.score(C_hat_test, Y_test)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "466a67ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: distinct class ids predicted on the test set by whatever `model` currently holds (the decision tree when run in order).\n",
    "np.unique(model.predict(C_hat_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5aff22cb",
   "metadata": {},
   "source": [
    "## MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6f9c3248",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP Test accuracy: 0.660759493670886\n"
     ]
    }
   ],
   "source": [
    "from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "# Three hidden layers (512 -> 256 -> 128). The seed pins the weight initialisation\n",
    "# and batch sampling so the reported accuracy is repeatable across runs.\n",
    "mlp = MLPClassifier(hidden_layer_sizes=(512, 256, 128), max_iter=1000, random_state=42)\n",
    "mlp.fit(C_hat_train, Y_train)\n",
    "print(f\"MLP Test accuracy: {mlp.score(C_hat_test, Y_test)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "9bd4c37b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: distinct class ids the MLP predicts on the test set.\n",
    "# Uses `mlp`, not `model` - at this point `model` still holds the decision tree.\n",
    "np.unique(mlp.predict(C_hat_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5950cfd",
   "metadata": {},
   "source": [
    "# Accuracy Using Class-Level Concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "ed1991c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy using absolute difference to class concepts: 0.6481\n"
     ]
    }
   ],
   "source": [
    "# Nearest-class-concept baseline: assign each test instance to the class whose\n",
    "# concept vector is closest in summed absolute difference (L1 distance).\n",
    "distances = []\n",
    "\n",
    "for test_instance in C_hat_test:\n",
    "    # One total deviation per class: |instance - class concepts| summed over the concept axis\n",
    "    per_class_distance = np.abs(test_instance - class_level_concepts).sum(axis=1)\n",
    "\n",
    "    # argmin picks the first class attaining the smallest distance\n",
    "    nearest_class = np.argmin(per_class_distance)\n",
    "    distances.append((per_class_distance[nearest_class], nearest_class))\n",
    "\n",
    "# Split the (distance, class) pairs into two arrays\n",
    "min_distances = np.array([dist for dist, _ in distances])\n",
    "predicted_classes = np.array([cls for _, cls in distances])\n",
    "\n",
    "# Fraction of test instances whose nearest class matches the label\n",
    "accuracy = np.mean(predicted_classes == Y_test)\n",
    "print(f\"Accuracy using absolute difference to class concepts: {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d9dc649-e4d0-40d5-aa23-bd8b5360bea6",
   "metadata": {},
   "source": [
    "# Prototype-Based Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d6be3f96",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "118ea801",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: mps\n"
     ]
    }
   ],
   "source": [
    "# Prefer CUDA, then the MPS backend, and fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14da7dda",
   "metadata": {},
   "source": [
    "## Create Dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "482ea172",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_split_ratio = 0.2\n",
    "random_seed = 42\n",
    "\n",
    "# Every dataset is converted the same way; no validation split is carved out here\n",
    "# (the train_test_split call below is disabled), so no per-dataset branch is needed.\n",
    "# C_hat_train, C_hat_val, Y_train_np, Y_val_np = train_test_split(C_hat_train, one_hot_Y_train, test_size=val_split_ratio, random_state=random_seed)\n",
    "# NOTE: Y_train / Y_test are rebound here from integer class ids (NumPy) to one-hot\n",
    "# float tensors; re-running the scikit-learn cells above after this cell will see the tensors.\n",
    "X_train = torch.tensor(C_hat_train, dtype=torch.float32)\n",
    "Y_train = torch.tensor(one_hot_Y_train, dtype=torch.float32)\n",
    "\n",
    "X_test = torch.tensor(C_hat_test, dtype=torch.float32, device=device)\n",
    "Y_test = torch.tensor(one_hot_Y_test, dtype=torch.float32, device=device)\n",
    "\n",
    "# DATALOADERS\n",
    "batch_size = 64\n",
    "train_dataset = TensorDataset(X_train, Y_train)\n",
    "# Seed the shuffling with `random_seed` so the batch order is repeatable across runs.\n",
    "train_loader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    generator=torch.Generator().manual_seed(random_seed),\n",
    ")\n",
    "\n",
    "test_dataset = TensorDataset(X_test, Y_test)\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "497a8c4e",
   "metadata": {},
   "source": [
    "## Learn Prototypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "eb8f8f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import PrototypeClassifier\n",
    "\n",
    "# Model dimensions come from the selected dataset's config.\n",
    "num_concepts = config_dict['N_TRIMMED_CONCEPTS']\n",
    "num_classes = config_dict['N_CLASSES']\n",
    "\n",
    "# NOTE: rebinds `model` (the decision tree when cells are run in order) to the prototype classifier.\n",
    "model = PrototypeClassifier(num_concepts, num_classes).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "# Presumably loss weights for a binarisation term and an L1 term used by the training loop - TODO confirm against src.training.\n",
    "lambda_binary = 0.01\n",
    "lambda_L1 = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "c7f50626",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[26], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# train and test\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtqdm\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tqdm\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtraining\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m train_epoch\n\u001b[1;32m      5\u001b[0m num_epochs \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m50\u001b[39m\n\u001b[1;32m      6\u001b[0m best_acc, best_epoch \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m\n",
      "File \u001b[0;32m~/Documents/career/lab_ujm/hybrid-cbm-prototype-model/src/training/__init__.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mphase_1_training\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mprototype_training\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m train_epoch, val_epoch\n\u001b[1;32m      4\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrun_epoch_x_to_c\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrun_epoch_x_to_y\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_epoch\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_epoch\u001b[39m\u001b[38;5;124m'\u001b[39m]\n",
      "File \u001b[0;32m~/Documents/career/lab_ujm/hybrid-cbm-prototype-model/src/training/phase_1_training.py:4\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtqdm\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tqdm\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m AverageMeter, binary_accuracy\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_get_outputs\u001b[39m(model, inputs):\n\u001b[1;32m      7\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m model(inputs)\n",
      "File \u001b[0;32m~/Documents/career/lab_ujm/hybrid-cbm-prototype-model/src/utils/__init__.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mhelpers\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m vprint, get_filename_to_id_mapping, load_concept_names, find_class_imbalance, get_paths, load_Derm_dataset\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m AverageMeter, binary_accuracy, accuracy\n\u001b[1;32m      4\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvprint\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mget_filename_to_id_mapping\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mload_concept_names\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfind_class_imbalance\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m      5\u001b[0m         \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mAverageMeter\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbinary_accuracy\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124maccuracy\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mget_paths\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mload_Derm_dataset\u001b[39m\u001b[38;5;124m'\u001b[39m]\n",
      "File \u001b[0;32m~/Documents/career/lab_ujm/hybrid-cbm-prototype-model/src/utils/helpers.py:7\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mconfig\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m PROJECT_ROOT\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mderm7pt\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdataset\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Derm7PtDatasetGroupInfrequent\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mvprint\u001b[39m(message, is_verbose):\n\u001b[1;32m     11\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m is_verbose:\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/derm7pt/dataset.py:5\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpandas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mpd\u001b[39;00m\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpreprocessing\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mimage\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m load_img\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mderm7pt\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m strings2numeric\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/__init__.py:7\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;124;03m\"\"\"DO NOT EDIT.\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \n\u001b[1;32m      3\u001b[0m \u001b[38;5;124;03mThis file was autogenerated. Do not edit it by hand,\u001b[39;00m\n\u001b[1;32m      4\u001b[0m \u001b[38;5;124;03msince your modifications would be overwritten.\u001b[39;00m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m _tf_keras \u001b[38;5;28;01mas\u001b[39;00m _tf_keras\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m activations \u001b[38;5;28;01mas\u001b[39;00m activations\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m applications \u001b[38;5;28;01mas\u001b[39;00m applications\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/_tf_keras/__init__.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_tf_keras\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m keras\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/_tf_keras/keras/__init__.py:7\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;124;03m\"\"\"DO NOT EDIT.\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \n\u001b[1;32m      3\u001b[0m \u001b[38;5;124;03mThis file was autogenerated. Do not edit it by hand,\u001b[39;00m\n\u001b[1;32m      4\u001b[0m \u001b[38;5;124;03msince your modifications would be overwritten.\u001b[39;00m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m activations \u001b[38;5;28;01mas\u001b[39;00m activations\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m applications \u001b[38;5;28;01mas\u001b[39;00m applications\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m callbacks \u001b[38;5;28;01mas\u001b[39;00m callbacks\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/activations/__init__.py:7\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;124;03m\"\"\"DO NOT EDIT.\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \n\u001b[1;32m      3\u001b[0m \u001b[38;5;124;03mThis file was autogenerated. Do not edit it by hand,\u001b[39;00m\n\u001b[1;32m      4\u001b[0m \u001b[38;5;124;03msince your modifications would be overwritten.\u001b[39;00m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mactivations\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m deserialize \u001b[38;5;28;01mas\u001b[39;00m deserialize\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mactivations\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get \u001b[38;5;28;01mas\u001b[39;00m get\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mactivations\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m serialize \u001b[38;5;28;01mas\u001b[39;00m serialize\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/__init__.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m activations\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m applications\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m backend\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/activations/__init__.py:3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtypes\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mactivations\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mactivations\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m celu\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mactivations\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mactivations\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m elu\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mactivations\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mactivations\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m exponential\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/activations/activations.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m backend\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m ops\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mapi_export\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m keras_export\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/backend/__init__.py:10\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mapi_export\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m keras_export\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdtypes\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m result_type\n\u001b[1;32m     11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mkeras_tensor\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m KerasTensor\n\u001b[1;32m     12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mkeras_tensor\u001b[39;00m\u001b[38;5;250m 
\u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m any_symbolic_tensors\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/backend/common/__init__.py:2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m backend_utils\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdtypes\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m result_type\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mvariables\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m AutocastScope\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mvariables\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Variable 
\u001b[38;5;28;01mas\u001b[39;00m KerasVariable\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/backend/common/dtypes.py:5\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mapi_export\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m keras_export\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m config\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mvariables\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m standardize_dtype\n\u001b[1;32m      7\u001b[0m BOOL_TYPES \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbool\u001b[39m\u001b[38;5;124m\"\u001b[39m,)\n\u001b[1;32m      8\u001b[0m INT_TYPES \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m      9\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muint8\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     10\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muint16\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     16\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint64\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     17\u001b[0m )\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/backend/common/variables.py:11\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mstateless_scope\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_stateless_scope\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcommon\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mstateless_scope\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m in_stateless_scope\n\u001b[0;32m---> 11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmodule_utils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tensorflow \u001b[38;5;28;01mas\u001b[39;00m tf\n\u001b[1;32m     12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnaming\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m auto_name\n\u001b[1;32m     
15\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mVariable\u001b[39;00m:\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/utils/__init__.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01maudio_dataset_utils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m audio_dataset_from_directory\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdataset_utils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m split_dataset\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfile_utils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_file\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/utils/audio_dataset_utils.py:4\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mapi_export\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m keras_export\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m dataset_utils\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmodule_utils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tensorflow \u001b[38;5;28;01mas\u001b[39;00m tf\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmodule_utils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tensorflow_io \u001b[38;5;28;01mas\u001b[39;00m tfio\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/utils/dataset_utils.py:9\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mmultiprocessing\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpool\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m ThreadPool\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m tree\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mapi_export\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m keras_export\n\u001b[1;32m     11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m io_utils\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/tree/__init__.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtree\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtree_api\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m assert_same_paths\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtree\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtree_api\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m assert_same_structure\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtree\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtree_api\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m flatten\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/tree/tree_api.py:8\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmodule_utils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m optree\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m optree\u001b[38;5;241m.\u001b[39mavailable:\n\u001b[0;32m----> 8\u001b[0m     \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtree\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m optree_impl \u001b[38;5;28;01mas\u001b[39;00m tree_impl\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m dmtree\u001b[38;5;241m.\u001b[39mavailable:\n\u001b[1;32m     10\u001b[0m     \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mkeras\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtree\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m dmtree_impl \u001b[38;5;28;01mas\u001b[39;00m tree_impl\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/keras/src/tree/optree_impl.py:13\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[38;5;66;03m# Register backend-specific node classes\u001b[39;00m\n\u001b[1;32m     12\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m backend() \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtensorflow\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m---> 13\u001b[0m     \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtensorflow\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtrackable\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata_structures\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m ListWrapper\n\u001b[1;32m     14\u001b[0m     \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtensorflow\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtrackable\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata_structures\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m _DictWrapper\n\u001b[1;32m     16\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/tensorflow/__init__.py:59\u001b[0m\n\u001b[1;32m     57\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtensorflow\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_api\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mv2\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m data\n\u001b[1;32m     58\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtensorflow\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_api\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mv2\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m debugging\n\u001b[0;32m---> 59\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtensorflow\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_api\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mv2\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m distribute\n\u001b[1;32m     60\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtensorflow\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_api\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mv2\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m dtypes\n\u001b[1;32m     61\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtensorflow\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_api\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mv2\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m errors\n",
      "File \u001b[0;32m~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/tensorflow/_api/v2/distribute/__init__.py:9\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01msys\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m_sys\u001b[39;00m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtensorflow\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_api\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mv2\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdistribute\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m cluster_resolver\n\u001b[0;32m----> 9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtensorflow\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_api\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mv2\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdistribute\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m coordinator\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtensorflow\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_api\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mv2\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdistribute\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m experimental\n\u001b[1;32m     11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m 
\u001b[39m\u001b[38;5;21;01mtensorflow\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdistribute\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcollective_all_reduce_strategy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m CollectiveAllReduceStrategy \u001b[38;5;28;01mas\u001b[39;00m MultiWorkerMirroredStrategy \u001b[38;5;66;03m# line: 56\u001b[39;00m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:230\u001b[0m, in \u001b[0;36m_lock_unlock_module\u001b[0;34m(name)\u001b[0m\n",
      "File \u001b[0;32m<frozen importlib._bootstrap>:125\u001b[0m, in \u001b[0;36mrelease\u001b[0;34m(self)\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Train the prototype model; this cell only reports training metrics (no evaluation happens here).\n",
    "from tqdm import tqdm\n",
    "from src.training import train_epoch\n",
    "\n",
    "num_epochs = 50\n",
    "best_acc = 0\n",
    "best_epoch = 0  # 0-based index into range(num_epochs)\n",
    "\n",
    "tqdm_loader = tqdm(range(num_epochs), desc=\"Training Prototypes\", leave=True)\n",
    "for epoch in tqdm_loader:\n",
    "    train_loss, train_accuracy = train_epoch(\n",
    "        model, train_loader, optimizer, lambda_binary, lambda_L1, device=device\n",
    "    )\n",
    "    # Strict '>' keeps the earliest epoch when later epochs merely tie the best accuracy.\n",
    "    if train_accuracy > best_acc:\n",
    "        best_acc, best_epoch = train_accuracy, epoch\n",
    "    tqdm_loader.set_postfix(\n",
    "        {\n",
    "            \"Train Acc\": f\"{train_accuracy:.2f}%\",\n",
    "            \"Train Loss\": f\"{train_loss:.4f}\",\n",
    "        }\n",
    "    )\n",
    "\n",
    "print(f\"Best accuracy of {best_acc:.2f}% achieved at epoch {best_epoch}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b82dc36",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9895951570185395"
      ]
     },
     "execution_count": 224,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test accuracy: fraction of model predictions equal to the argmax of Y_test along dim 1.\n",
    "real_labels = torch.argmax(Y_test, dim=1)\n",
    "predictions = model.predict(X_test)\n",
    "torch.eq(predictions, real_labels).sum().item() / len(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c674a5b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([55, 18])\n",
      "torch.Size([55])\n"
     ]
    }
   ],
   "source": [
    "# Collect the misclassified test samples and store them under output/intervention/<dataset>.\n",
    "wrong_indices = (predictions != real_labels).nonzero(as_tuple=True)[0]\n",
    "\n",
    "C_hat_wrong = X_test[wrong_indices]\n",
    "Y_wrong = real_labels[wrong_indices]\n",
    "\n",
    "print(C_hat_wrong.shape)\n",
    "print(Y_wrong.shape)\n",
    "\n",
    "output_dir = os.path.join(PROJECT_ROOT, 'output', 'intervention', use_dataset)\n",
    "# torch.save does not create missing parent directories, so make sure the target exists first.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# NOTE: despite the '.npy' suffix these files are PyTorch-serialized tensors (torch.save),\n",
    "# not NumPy arrays; read them back with torch.load.\n",
    "torch.save(C_hat_wrong, os.path.join(output_dir, 'C_hat_clc.npy'))\n",
    "torch.save(Y_wrong, os.path.join(output_dir, 'Y_clc.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9ecabf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the model\n",
    "# torch.save(model, ...) pickles the whole model object rather than a state_dict,\n",
    "# so loading it later requires the model's class definition to be importable.\n",
    "torch.save(model, os.path.join(model_dir, 'clc_model.pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d85fa563",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.unique(predictions.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d004525d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      1.00      1.00       533\n",
      "           1       0.99      0.99      0.99       534\n",
      "           2       1.00      0.99      0.99       532\n",
      "           3       0.99      0.99      0.99       530\n",
      "           4       0.99      1.00      0.99       531\n",
      "           5       0.99      0.99      0.99       533\n",
      "           6       0.99      1.00      1.00       530\n",
      "           7       0.99      0.98      0.98       529\n",
      "           8       0.99      1.00      0.99       532\n",
      "           9       0.97      0.98      0.98       502\n",
      "\n",
      "    accuracy                           0.99      5286\n",
      "   macro avg       0.99      0.99      0.99      5286\n",
      "weighted avg       0.99      0.99      0.99      5286\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "# Copy both label tensors to CPU NumPy arrays before handing them to scikit-learn.\n",
    "y_true, y_pred = (labels.cpu().numpy() for labels in (real_labels, predictions))\n",
    "\n",
    "# Per-class precision / recall / F1 for the predictions computed above.\n",
    "print(classification_report(y_true, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ade5c6e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8035714626312256% of the values are close to 0 or 1\n"
     ]
    }
   ],
   "source": [
    "# Share of sigmoid prototype entries that are saturated, i.e. below 0.1 or above 0.9.\n",
    "saturated_mask = (model.get_sigmoid_prototypes() < 0.1) | (model.get_sigmoid_prototypes() > 0.9)\n",
    "# Average over the actual number of entries; the previous hard-coded 200*112 denominator\n",
    "# was only correct for a 200x112 prototype matrix and under-reported every other shape.\n",
    "close_to_zero = saturated_mask.float().mean().cpu().numpy()\n",
    "print(f\"{close_to_zero*100}% of the values are close to 0 or 1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f6cf8c9",
   "metadata": {},
   "source": [
    "# Class-level vs Learned"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2985a2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall sparsity: 0.7500\n"
     ]
    }
   ],
   "source": [
    "# print(class_level_concepts)\n",
    "\n",
    "# Fraction of entries in the class-level concept matrix that are exactly zero.\n",
    "overall_sparsity = (class_level_concepts == 0).mean()\n",
    "print(f\"Overall sparsity: {overall_sparsity:.4f}\")\n",
    "\n",
    "# # Sparsity per row (fraction of zeros in each row)\n",
    "# row_sparsity = np.mean(class_level_concepts == 0, axis=1)\n",
    "# print(\"Sparsity per row:\", row_sparsity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c8ae91e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall sparsity: 0.7500\n"
     ]
    }
   ],
   "source": [
    "# Binary prototypes from the model, detached and copied to a CPU NumPy array.\n",
    "Prototypes = model.get_binary_prototypes().cpu().detach().numpy()\n",
    "# print(Prototypes)\n",
    "\n",
    "# Fraction of learned prototype entries that are exactly zero.\n",
    "overall_sparsity = (Prototypes == 0).mean()\n",
    "print(f\"Overall sparsity: {overall_sparsity:.4f}\")\n",
    "\n",
    "# Sparsity per row (fraction of zeros in each row)\n",
    "# row_sparsity = np.mean(Prototypes == 0, axis=1)\n",
    "# print(\"Sparsity per row:\", row_sparsity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9e88c96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prototypes = model.get_sigmoid_prototypes()\n",
    "# Prototypes = Prototypes.cpu().detach().numpy()\n",
    "# print(Prototypes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a7dc8f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0% of the values are close to 0.5\n"
     ]
    }
   ],
   "source": [
    "# Share of sigmoid prototype entries that are still undecided, i.e. strictly between 0.4 and 0.6.\n",
    "near_half_mask = (model.get_sigmoid_prototypes() > 0.4) & (model.get_sigmoid_prototypes() < 0.6)\n",
    "# Average over the actual number of entries; the previous hard-coded 200*112 denominator\n",
    "# was only correct for a 200x112 prototype matrix.\n",
    "close_to_half = near_half_mask.float().mean().cpu().numpy()\n",
    "print(f\"{close_to_half*100}% of the values are close to 0.5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38c33e67",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1],\n",
       "       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],\n",
       "       [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],\n",
       "       [1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],\n",
       "       [1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],\n",
       "       [1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1],\n",
       "       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1],\n",
       "       [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1],\n",
       "       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0],\n",
       "       [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0]])"
      ]
     },
     "execution_count": 234,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the class-level concept matrix so it can be compared with the learned binary prototypes printed in the next cell.\n",
    "class_level_concepts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "658c8f7e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1.]\n",
      " [1. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1.]\n",
      " [1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n",
      " [1. 0. 0. 0. 0. 0. 1. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1.]\n",
      " [0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 1.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 0.]\n",
      " [0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 1. 1. 0.]]\n"
     ]
    }
   ],
   "source": [
    "# Print the model's binary prototypes (detached, on CPU) for comparison with class_level_concepts shown in the previous cell.\n",
    "print(model.get_binary_prototypes().cpu().detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51a31bcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inverse_sigmoid(x, epsilon=1e-7):\n",
    "    \"\"\"Return the logit ``log(x / (1 - x))`` of the probabilities ``x``.\n",
    "\n",
    "    Args:\n",
    "        x: Scalar or array-like of probabilities. Values are clipped to\n",
    "            ``[epsilon, 1 - epsilon]`` first, so exact 0/1 inputs map to large\n",
    "            finite logits instead of -inf/+inf.\n",
    "        epsilon: Clipping margin; must satisfy ``0 < epsilon < 0.5``.\n",
    "\n",
    "    Returns:\n",
    "        The element-wise logit of the clipped input, as produced by ``np.log``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``epsilon`` is not inside ``(0, 0.5)``.\n",
    "    \"\"\"\n",
    "    if not 0 < epsilon < 0.5:\n",
    "        raise ValueError(f\"epsilon must be in (0, 0.5), got {epsilon}\")\n",
    "    x = np.clip(x, epsilon, 1 - epsilon)\n",
    "    return np.log(x / (1 - x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82772330",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derm7pt only: overwrite the first five prototype rows with values derived from class_level_concepts.\n",
    "if use_dataset == 'Derm7pt':\n",
    "    # Start from the output of get_sigmoid_prototypes(), copied to a CPU NumPy array.\n",
    "    new_prototypes = model.get_sigmoid_prototypes().cpu().detach().numpy()\n",
    "    # Apply inverse sigmoid to convert binary class-level concepts to logits\n",
    "    # NOTE(review): only rows 0-4 are replaced with logits; any further rows would keep the\n",
    "    # values returned by get_sigmoid_prototypes(). Confirm the model has exactly five\n",
    "    # prototypes and that modify_prototypes expects logits - TODO verify.\n",
    "    for i in range(5):\n",
    "        new_prototypes[i] = inverse_sigmoid(class_level_concepts[i])\n",
    "\n",
    "    # Convert back to a float32 tensor on the training device before handing it to the model.\n",
    "    new_prototypes = torch.tensor(new_prototypes, device=device, dtype=torch.float32)\n",
    "    print(new_prototypes)\n",
    "\n",
    "    model.modify_prototypes(new_prototypes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c75f75c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,\n",
       "        0., 1.],\n",
       "       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,\n",
       "        0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
       "        0., 1.],\n",
       "       [1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
       "        0., 1.],\n",
       "       [1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n",
       "        0., 0.],\n",
       "       [1., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 0.,\n",
       "        0., 1.],\n",
       "       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n",
       "        0., 1.],\n",
       "       [0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,\n",
       "        1., 1.],\n",
       "       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.,\n",
       "        1., 0.],\n",
       "       [0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1.,\n",
       "        1., 0.]], dtype=float32)"
      ]
     },
     "execution_count": 238,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the binary prototypes again, after the Derm7pt-only override in the previous cell has (possibly) run.\n",
    "model.get_binary_prototypes().cpu().detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a47d378c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the model\n",
    "# Writes to the same 'clc_model.pth' file name as the earlier save cell, so (assuming model_dir\n",
    "# is unchanged) this replaces that checkpoint with the model as it stands after the cells above.\n",
    "torch.save(model, os.path.join(model_dir, 'clc_model.pth'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61b09a36",
   "metadata": {},
   "source": [
    "# MY OLD CODE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "091fbd97",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfe9091f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # --- Plotting ---\n",
    "# from matplotlib import pyplot as plt\n",
    "\n",
    "# plt.figure(figsize=(10, 5))\n",
    "# epochs_range = range(1, epochs + 1)\n",
    "# plt.plot(epochs_range, train_losses, label='Training Loss', marker='o', linestyle='-')\n",
    "# plt.plot(epochs_range, val_losses, label='Validation Loss', marker='x', linestyle='--')\n",
    "# plt.title('Training and Validation Loss Over Epochs')\n",
    "# plt.xlabel('Epoch')\n",
    "# plt.ylabel('Average Loss')\n",
    "# plt.legend()\n",
    "# plt.grid(True)\n",
    "# plt.show()\n",
    "\n",
    "# # Optional: Plot validation accuracy as well\n",
    "# plt.figure(figsize=(10, 5))\n",
    "# plt.plot(epochs_range, val_accuracies, label='Validation Accuracy', marker='s', linestyle='-', color='green')\n",
    "# plt.title('Validation Accuracy Over Epochs')\n",
    "# plt.xlabel('Epoch')\n",
    "# plt.ylabel('Accuracy (%)')\n",
    "# plt.legend()\n",
    "# plt.grid(True)\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a987f12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prototypes = []\n",
    "# for y in Y_train:\n",
    "#     prototypes.append(final_binary_prototypes[y])\n",
    "\n",
    "# prototypes = np.array(prototypes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daad5b1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Function to find the closest concept vector and predict the label\n",
    "# def predict_nearest_concept(instance, reference_concepts, reference_labels):\n",
    "#     distances = np.sum(np.abs(reference_concepts - instance), axis=1)\n",
    "#     min_idx = np.argmin(distances)\n",
    "#     return reference_labels[min_idx]\n",
    "\n",
    "# # Use prototypes as reference concepts and evaluate on C_hat_test\n",
    "# correct_predictions = 0\n",
    "# total_predictions = len(C_hat_test)\n",
    "\n",
    "# for i, test_instance in enumerate(C_hat_test):\n",
    "#     predicted_label = predict_nearest_concept(test_instance, prototypes, Y_train)\n",
    "#     true_label = Y_test[i]\n",
    "\n",
    "#     if predicted_label == true_label:\n",
    "#         correct_predictions += 1\n",
    "\n",
    "# # Calculate and print accuracy\n",
    "# accuracy = correct_predictions / total_predictions\n",
    "# print(f\"\\nOverall accuracy using prototype-based nearest neighbor: {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8666fb86",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
